Refactor set_multi_step_attn_mask for arbitrary step #348
Walkthrough
Sequence Diagram(s)
sequenceDiagram
  autonumber
  participant Caller
  participant Eagle as MegatronEagle
  rect rgba(220,235,245,0.5)
    note over Eagle: set_multi_step_attention_mask(attn_mask, step)
    Caller->>Eagle: set_multi_step_attention_mask(attn_mask, step)
    loop iter = 2 .. step
      Eagle->>Eagle: build zero_mask, mask_0, mask_1
      Eagle->>Eagle: concat per-iteration masks into attn_mask
    end
    Eagle-->>Caller: updated attn_mask
  end
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
317-318: Optional: vectorize the diagonal write to avoid a Python loop.
Minor perf tidy-up; reduces Python overhead for large s.
- for i in range(step_idx - 1, s - 1):
-     mask_1[:, :, i, i] = False
+ if step_idx - 1 < s - 1:
+     idx = torch.arange(step_idx - 1, s - 1, device=attn_mask.device)
+     mask_1[:, :, idx, idx] = False
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- modelopt/torch/speculative/plugins/megatron_eagle.py(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
309-327: Generalization approach looks good.
The iterative block-growth builds a (steps)×(steps) mask and removes the hardcoded branches. Nice cleanup.
Please re-run the sandbox regression with:
- CUDA tensors (to confirm device fix), and
- very short sequences (e.g., s ∈ {1,2,3}) with step ∈ {2,3,4} to ensure no out-of-bounds and correct shapes.
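The short-sequence check requested above can be sketched without a GPU or PyTorch. The block below is a 2-D, pure-Python model of the loop body quoted in the review diffs (batch and head dimensions dropped, no step cap applied). The function names `grow_multi_step_mask` and `check_grid` are hypothetical, and the block layout (`[attn_mask | zero_mask]` stacked over `[mask_0 | mask_1]`) is an assumption inferred from the tensor shapes in the diff, not copied from the repository.

```python
def grow_multi_step_mask(base, step):
    """Pure-Python 2-D model of the per-step mask growth (assumed layout)."""
    s = len(base)
    mask = [row[:] for row in base]
    for step_idx in range(2, step + 1):
        cols = len(mask[0])
        # mask_0: copy of the last s rows of the current mask.
        mask_0 = [row[:] for row in mask[-s:]]
        # Raises IndexError when step_idx - 2 >= s (the short-sequence case).
        mask_0[step_idx - 2] = [True] * cols
        # Shift every row left by one column; the last column keeps its value.
        for row in mask_0:
            row[:-1] = row[1:]
        # mask_1: all True, with part of the diagonal cleared.
        mask_1 = [[True] * s for _ in range(s)]
        for i in range(step_idx - 1, s - 1):
            mask_1[i][i] = False
        # Assumed layout: [mask | all-True block] on top, [mask_0 | mask_1] below.
        top = [row + [True] * s for row in mask]
        bottom = [r0 + r1 for r0, r1 in zip(mask_0, mask_1)]
        mask = top + bottom
    return mask


def check_grid(seq_lens=(1, 2, 3), steps=(2, 3, 4)):
    """Return the resulting shape, or 'IndexError', for each (s, step) pair."""
    results = {}
    for s in seq_lens:
        # Causal base mask: True above the diagonal.
        base = [[c > r for c in range(s)] for r in range(s)]
        for step in steps:
            try:
                m = grow_multi_step_mask(base, step)
                results[(s, step)] = (len(m), len(m[0]))
            except IndexError:
                results[(s, step)] = "IndexError"
    return results


if __name__ == "__main__":
    for key, value in sorted(check_grid().items()):
        print(key, value)
```

In this model every successful case yields a (step·s)×(step·s) mask, and the failing pairs are exactly those with step > s + 1. The real function operates on 4-D tensors, so this only checks the indexing logic, not device or dtype handling.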
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #348      +/-   ##
==========================================
- Coverage   73.84%   73.83%   -0.01%     
==========================================
  Files         172      172              
  Lines       17453    17453              
==========================================
- Hits        12888    12887       -1     
- Misses       4565     4566       +1     
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
309-327: Fix device/dtype mismatch and short-sequence OOB; don't shadow built-in iter.
- zero_mask/mask_1 are allocated on CPU; concatenating with CUDA attn_mask will crash.
- For short sequences, mask_0[:, :, iter-2] can index OOB when step > s+1.
- Unnecessary full clone before slicing; and iter shadows the Python builtin.
Apply this diff:
 s = attn_mask.shape[-1]
-for iter in range(2, step + 1):
-    # iter starts from 2nd step
-    zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
-    mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
-    mask_0[:, :, iter - 2] = True
-    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
-    mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-    for i in range(iter - 1, s - 1):
-        mask_1[:, :, i, i] = False
+max_step = min(int(step), s + 1)
+if max_step != step:
+    warnings.warn(
+        f"set_multi_step_attention_mask: capping step from {step} to {max_step} for base seq len s={s}."
+    )
+for step_idx in range(2, max_step + 1):
+    # step_idx starts from 2nd step
+    zero_mask = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s)
+    mask_0 = attn_mask[:, :, -s:, :].clone()
+    row_idx = step_idx - 2
+    if row_idx < s:
+        mask_0[:, :, row_idx] = True
+    mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+    mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s)
+    for i in range(step_idx - 1, s - 1):
+        mask_1[:, :, i, i] = False
Optional follow-up: vectorize the diagonal update later for perf.
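The capping step in the suggested diff can be isolated and exercised on its own. This is a minimal sketch using only the standard library; the helper name `cap_step` is hypothetical and does not exist in the repository, and the warning text is shortened from the one in the diff.

```python
import warnings


def cap_step(step, s):
    """Clamp step to s + 1 and warn when the value changes (hypothetical helper)."""
    max_step = min(int(step), s + 1)
    if max_step != step:
        warnings.warn(
            f"capping step from {step} to {max_step} for base seq len s={s}."
        )
    return max_step


if __name__ == "__main__":
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        print(cap_step(4, 2))   # request exceeds s + 1, so it is capped
        print(cap_step(3, 2))   # already within range, no warning
        print(len(caught))      # number of warnings recorded
```

One design consequence worth weighing: a capped call returns a smaller mask than the caller asked for, so code that assumes step·s rows would need to read the effective step back rather than rely on the requested one.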
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
845-849: Avoid in-place aliasing; use the original mask as RHS.
Assigning from attn_mask into an overlapping slice is unnecessary and can be error-prone. Read from attention_mask (the source you just cloned) for clarity and safety.
- attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:]
+ attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
57f9143 to 66f5c67
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
312-314: Nit: rename zero_mask → full_mask for clarity.
It's an all-True mask; the current name is misleading. Optional rename within this scope.
🔇 Additional comments (2)
modelopt/torch/speculative/plugins/megatron_eagle.py (2)
324-326: Concatenation pattern LGTM.
The two-part concat per iteration keeps shapes consistent and is readable.
309-320: Cap step to base seq length; avoid OOB and rename loop var.
For short sequences and large step, mask_0[:, :, iter - 2, :] indexes out of bounds. Also, shadowing Python's built-in iter is undesirable.
Apply:
- s = attn_mask.shape[-1]
- for iter in range(2, step + 1):
+ s = attn_mask.shape[-1]
+ max_step = min(int(step), s + 1)
+ if max_step != step:
+     warnings.warn(
+         f"set_multi_step_attention_mask: capping step from {step} to {max_step} for base seq len s={s}."
+     )
+ for step_idx in range(2, max_step + 1):
      # iter starts from 2nd step
-     zero_mask = attn_mask.new_ones(
+     zero_mask = attn_mask.new_ones(
          attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s
      ).bool()
-     mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
-     mask_0[:, :, iter - 2, :] = True
-     mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+     mask_0 = attn_mask[:, :, -s:, :].clone()
+     mask_0[:, :, step_idx - 2, :] = True
+     mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:].clone()
-     mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-     for i in range(iter - 1, s - 1):
+     mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
+     for i in range(step_idx - 1, s - 1):
          mask_1[:, :, i, i] = False
Signed-off-by: Ye Yu <[email protected]>
f84d09a to b8cc0de
What does this PR do?
Type of change: refactor
Overview:
Our current set_multi_step_attn_mask function hardcodes the attention mask for step=2,3,4 and does not support an arbitrary step. This PR generalizes it to support any step > 1.
Usage
# Add a code snippet demonstrating how to use this
Testing
Tested the attention mask locally; it passes the sandbox regression test.
Before your PR is "Ready for review"
Additional Information